{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Merging EBMs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we will fit several Explainable Boosting Machines (EBMs) and then merge them into a single EBM model. The resulting EBM model retains the glassbox interpretability of the component EBM models, allowing us to view both global and local explanations of the merged model.\n",
    "\n",
    "This notebook can be found in our [**_examples folder_**](https://github.com/interpretml/interpret/tree/develop/docs/interpret/python/examples) on GitHub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# install interpret if not already installed\n",
    "try:\n",
    "    import interpret\n",
    "except ModuleNotFoundError:\n",
    "    !pip install --quiet interpret pandas scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from interpret import show\n",
    "\n",
    "from interpret import set_visualize_provider\n",
    "from interpret.provider import InlineProvider\n",
    "set_visualize_provider(InlineProvider())\n",
    "\n",
    "df = pd.read_csv(\n",
    "    \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\",\n",
    "    header=None)\n",
    "df.columns = [\n",
    "    \"Age\", \"WorkClass\", \"fnlwgt\", \"Education\", \"EducationNum\",\n",
    "    \"MaritalStatus\", \"Occupation\", \"Relationship\", \"Race\", \"Gender\",\n",
    "    \"CapitalGain\", \"CapitalLoss\", \"HoursPerWeek\", \"NativeCountry\", \"Income\"\n",
    "]\n",
    "X = df.iloc[:, :-1]\n",
    "y = df.iloc[:, -1]\n",
    "\n",
    "seed = 42\n",
    "np.random.seed(seed)\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Explore the dataset</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from interpret.data import ClassHistogram\n",
    "\n",
    "hist = ClassHistogram().explain_data(X_train, y_train, name = 'Train Data')\n",
    "show(hist)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Training multiple EBM models</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from interpret.glassbox import ExplainableBoostingClassifier\n",
    "\n",
    "# Fitting multiple EBM models with different training datasets and random seeds\n",
    "seed = 1\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=seed)\n",
    "ebm1 = ExplainableBoostingClassifier(random_state=seed)\n",
    "ebm1.fit(X_train, y_train)  \n",
    "\n",
    "seed +=10\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=seed)\n",
    "ebm2 = ExplainableBoostingClassifier(random_state=seed)\n",
    "ebm2.fit(X_train, y_train)  \n",
    "\n",
    "seed +=10\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=seed)\n",
    "ebm3 = ExplainableBoostingClassifier(random_state=seed)\n",
    "ebm3.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show(ebm1.explain_global(name='EBM 1'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Merging multiple trained EBM models</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from interpret.glassbox import merge_ebms\n",
    "\n",
    "merged_ebm = merge_ebms([ebm1, ebm2 , ebm3])\n",
    "show(merged_ebm.explain_global(name='Merged EBM'))"
   ]
  }
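  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The introduction mentions local explanations as well as global ones. As a minimal sketch, assuming the `X_test` and `y_test` left over from the last split above are still in scope, the merged model can be asked to explain a handful of individual predictions. The choice of five rows is arbitrary. Because each component EBM was fit on a different split, some of these rows were likely seen during training by `ebm1` or `ebm2`, so treat this as an illustration rather than a held-out evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: explain the first five rows of the most recent test split.\n",
    "# X_test and y_test come from the split used to fit ebm3 (an assumption about notebook state).\n",
    "show(merged_ebm.explain_local(X_test[:5], y_test[:5], name='Merged EBM'))"
   ]
  }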
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
